
Add weight tying support for Llama3#2580

Merged
tianyu-l merged 3 commits into pytorch:main from dean-mccoppin:feat/llama3-weight-tying
Mar 24, 2026
Conversation

@dean-mccoppin
Contributor

Implements enable_weight_tying for Llama3, sharing tok_embeddings.weight with output.weight. It mirrors the Qwen3 implementation from #1590 (thanks!).

Changes cover model.py (config field, tying in __init__/init_weights, PP guard), parallelize.py (grouped FSDP unit for tied params), state_dict_adapter.py (skip/reconstruct output.weight for HF conversion), and a new unit test file.

Closes #1524
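For readers unfamiliar with the technique: weight tying makes the token embedding and the LM head share a single parameter tensor. A minimal sketch of the idea in plain PyTorch (class and field names here are simplified stand-ins, not torchtitan's actual model code):

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Simplified decoder illustrating embedding/output weight tying."""

    def __init__(self, vocab_size: int, dim: int, enable_weight_tying: bool = False):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # nn.Linear(dim, vocab_size) stores its weight as (vocab_size, dim),
        # the same shape as the embedding table, so the tensors can be shared.
        self.output = nn.Linear(dim, vocab_size, bias=False)
        if enable_weight_tying:
            # Rebind the head's Parameter to the embedding's Parameter:
            # one tensor, one set of gradients, updated once per step.
            self.output.weight = self.tok_embeddings.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))

model = TinyDecoder(vocab_size=128, dim=16, enable_weight_tying=True)
assert model.output.weight is model.tok_embeddings.weight
# nn.Module.parameters() deduplicates shared Parameters, so the tied
# weight is counted once: 128 * 16 = 2048 parameters total.
n_params = sum(p.numel() for p in model.parameters())
```

Because the head and the embedding are one Parameter object, anything that iterates parameters (optimizers, FSDP sharding, state-dict export) sees it once, which is why the PR also has to touch parallelize.py and state_dict_adapter.py.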

meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Mar 15, 2026
Contributor

@tianyu-l tianyu-l left a comment


IIUC the existing model registry doesn't have the llama3.2 1B / 3B models, which are the only variants that have weight tying enabled. Please add those models to llama3/__init__.py. You can refer to the exact config in the earlier attempt #1376.
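A rough sketch of what such registry entries could look like. The dataclass and field names below are hypothetical stand-ins for torchtitan's actual model-args API, and the dimension/layer/head values are taken from the public Llama 3.2 model cards; the exact configs should be checked against PR #1376 as the reviewer suggests.

```python
from dataclasses import dataclass

@dataclass
class Llama3Args:
    # Hypothetical config container; torchtitan's real args class differs.
    dim: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    enable_weight_tying: bool = False

# Candidate entries for llama3/__init__.py; values per public
# Llama 3.2 model cards (verify against #1376 before relying on them).
llama3_configs = {
    "3.2-1B": Llama3Args(dim=2048, n_layers=16, n_heads=32,
                         n_kv_heads=8, enable_weight_tying=True),
    "3.2-3B": Llama3Args(dim=3072, n_layers=28, n_heads=24,
                         n_kv_heads=8, enable_weight_tying=True),
}
```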

Ties tok_embeddings.weight to output.weight via enable_weight_tying config flag.
Follows the same pattern as Qwen3 (pytorch#1590).

Closes pytorch#1524.
Llama 3.2 1B and 3B are the only Llama variants with weight tying, so
they belong in the registry. Without them the feature has no real entry
point.

Also dropped the try/except guard in test_weight_tying.py, which was
inconsistent with every other unit test here and silently skipped the
tests on broken imports.
@dean-mccoppin dean-mccoppin force-pushed the feat/llama3-weight-tying branch from 8d7f787 to ceda986 Compare March 20, 2026 20:06
Contributor

@tianyu-l tianyu-l left a comment


sgtm

@tianyu-l
Contributor

please fix tests

@tianyu-l tianyu-l merged commit 8953d2e into pytorch:main Mar 24, 2026
10 of 11 checks passed
pytorch-bot bot pushed a commit that referenced this pull request Mar 27, 2026


Development

Successfully merging this pull request may close these issues.

Weight tying between embedding and LM head layer
